{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from patsy import ContrastMatrix\n",
    "from sklearn.model_selection import train_test_split\n",
    "import statsmodels.api as sm\n",
    "from statsmodels.graphics.mosaicplot import mosaic\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load ./data/adult.data and keep only the columns used for modelling.\n",
    "cols = ['workclass', 'sex', 'age', 'education_num', 'capital_gain',\n",
    "        'capital_loss', 'hours_per_week', 'label']\n",
    "# Take an explicit copy of the column subset so that the column added below\n",
    "# is written to an independent frame (avoids the SettingWithCopy hazard).\n",
    "data = pd.read_csv('./data/adult.data')[cols].copy()\n",
    "# Integer-encode the income label via its sorted categories; with the two\n",
    "# label values ' <=50K' and ' >50K' this gives 0 and 1 respectively.\n",
    "data['label_code'] = pd.Categorical(data.label).codes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>label</th>\n",
       "      <th>&lt;=50K</th>\n",
       "      <th>&gt;50K</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>sex</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Female</th>\n",
       "      <td>9592</td>\n",
       "      <td>1179</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Male</th>\n",
       "      <td>15128</td>\n",
       "      <td>6662</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "label     <=50K   >50K\n",
       "sex                   \n",
       " Female    9592   1179\n",
       " Male     15128   6662"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Contingency table of sex (rows) against income label (columns)\n",
    "cross1 = pd.crosstab(index=data['sex'], columns=data['label'])\n",
    "cross1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADnCAYAAAC9roUQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/OQEPoAAAACXBIWXMAAAsTAAALEwEAmpwYAAAPoUlEQVR4nO3de1SV9Z7H8c8GtoCAFxBBxCQx85KaM6iJmibacDLPxKikRmealp3GcqpJqzmtWseZbGrmRKe7TTWnTnkqSDIrx8I0ylsqdijMvOGNIgkBExCIyzN/uNwLDurxAt8N2/drLdeK5/fsp99ja717ePb+7cflOI4AADb8vD0BALiYEF0AMER0AcAQ0QUAQ0QXAAwFnGmwW3gPp1dsnNFU0FpcLqmzWzpeJ/HhFMDezvxtRxzHiTzV2Bmj2ys2Tn9Ymds2s0KbGtNH2lTo7VkAF6fES1wHTzfG7QUAMER0AcAQ0QUAQ0QXAAwRXQAwRHQBwBDRBQBDRBcADBFdADBEdAHAENEFAENEFwAMEV0AMER0AcAQ0QUAQ0QXAAwRXQAwRHQBwBDRBQBDRBcADBFdADBEdAHAENEFAENEFwAMEV0AMER0AcAQ0QUAQ0QXAAwRXQAwRHQBwBDRBQBDRBcADBFdADBEdAHAENEFAENEFwAMEV0AMER0AcAQ0QUAQ0QXAAwRXQAwRHQBwBDRBQBDRBcADBFdADBEdAHAENEFAENEFwAMEV0AMER0AcAQ0QUAQ0QXAAwRXQAwRHQBwBDRBQBDRBcADBFdADBEdAHAENEFAENEFwAMEV0AMER0AcAQ0QUAQ0QXAAwRXQAwRHQBL3jlyUWaPvZSz89rPshU4iUuHa+qPOX+X27K0bOLF1pND22I6AJe0rV7D337Va4kaf0nH+iywcO9PCNYILqAl1xz3QzlrMpSbU216n6uVWiXbpKkgp35umPmBN12wxilPzy/xeu+yPlI86aP169TEpW94i3jWeNCEV3AS/oNGKL9u7/Rpk9XafSEv/Nsj43rr+czc/Tye5tUXFSowv17PGOO4+jVpx/RM2+t0ZJl65T12nNqaGjwxvRxngK8PQHgYtZv4FC98cLjSv/j/+mjd9+QJBUd2q9nFy9QTfVxFR3apyPFRZ79y0tLVLh/t+5Ju1aSVHnsqI6WliiiZ7RX5o9zR3QBL0pOSZMkdQvv4dm2fOkSzb5tgUaOn6z7b/2lHMfxjHUL76G+8QP11NJsuTt1Un1dnQLcbvN54/wRXcCL4i4bpH++/9Fm28ZOnqanFt2tvv0HqrGxsdmYn5+f/vGuh3T3TVPk5+enbuGRWrwk03LKuECupv8X/UuDhiU4f1iZazgdtJYxfaRNhd6eBXBxSrzEtc1xnIRTjfFGGgAYIroAYIjoAu3UK08uUtqUobozdaKeeWSBJKmhoUGPLrxV86aP11OL7vHse+vUE7/JlpeWaO4vR2v3N3lemDHOBtEFvKD6eNVZ7Tfvgcf0fGaO7no4XZK0Yc2H6hEVoyVZ61R9vEr52zZ59q2qOKbf3Jaif3koXQOGXNkW00YrILqAkfr6en328Xt6YO4NeuXJ357Va15Kf1h3zJyg3A1rJUn5uRs16uoTn9G9amKy8nM3SJJqa2v04O3T9av5D2r4qHFtcwJoFXxkDGhjPxQe0Luvv6Bv83N11YRk3ffoEvWI6iVJWpX1hj7M+N9m+182+Erds+gppd56l+beu0hlJcW6a85kvbpymyp+KldIaBdJUmhYVx07WiZJKi46pMDAII0aP8X03HDuiC7Qxr79Olfr13ygf7j5Dl17wxx17R7hGfvF9Jv1i+k3n/J1XbqFS5LCI6MU13+QfvzhO4V16aaqymOSpMqKnzz7XHLpAE25YY4eu3+uHv79H9v4jHAhuL0AtLFJU2fo9Y++UniPKC2+9xY9NC/Vc7tgVdYbujN1YrM/J98gq6o4Edea6uM6WLBTPaJ66YqE
ROWu/0SStPmzjzU0Yazn3zP7tnsVHBKqV9LP7tYFvIMrXcCAu1MnJU1LVdK0VB3+/pC++fILSWe+0n3u0ftUsCtfjQ0N+tX8BxUYFKyxSddr3cfvad708bpsyAgN/dsxzV7zr//+jP5t7g1a+c5rmjrzlrY+LZwHVqT5KFakAd7DijQAaCeILgAYIroAYMgnott0ueTdN13bqse+M3XiaR8WCADnymc+vTDvgcc0dvL13p4GAJyRz0S3Kcdx9Pvf3qV9u7bLz89fDz35mnr2itXsSYM0ePgo7dr+pW66/T6tW71Ch/bt1v3/+aKGjRyrp//jXu3K36bammo98PhLzdav19bU6LEH5upIcZE6dw7Vb59eqpCwLt47SQAdks9Ed8l//UZvvvSEBg8fpStHX62wrt31XMan+ubPm/XGC49rwSPPqezHw1qw+HkdKS7S/NSJemf9Ph3at0tvvZSuYSPH6vb7FisouLN2bf+z3vyf32nRM3/yHP+Dt19RQuIkXX/jrfrk/QytePMlzbl9oRfPGEBH5DPRbXp7YemS/9ZnHy1X3ubP5TiOesb0kSTFXNJPnUNC1SMqRrGXXqbAoCBFRvdWxU/lkqQ/vfg7z2of/4DmfzX79+zQt19t1aqs11VfV6fho8Ybnh0AX+Ez0W2qb/xAJV2fqn+6+2FJUn1dnSTJ5XJ59mn6z47j6KfyUm1dt1ovvrteO7/epmcXL2hxzCv+Zoxn9dDJYwLAufCJTy/8pXFTpumn8lLNv/EazZ81SauyXv+rrwnr2l1duoXrztSJWrvynRbjfz/n19q6brXmz5qk+bMmafPn2W0xdQA+jmXAPoplwID3sAwYANoJogsAhi7q6PLgPwDWfDa6PPgPQHvkU9HlwX8A2juf+JwuD/4D0FH4RHR58B+AjsInbi/w4D8AHYVPXOlKPPgPQMfAijQfxYo0wHtYkQYA7QTRBQBDRBcADBFdADBEdAHAENEFAENEFwAMEV0AMER0AcAQ0QUAQ0QXAAwRXQAwRHQBwBDRBQBDRBcADBFdADBEdAHAENEFAENEFwAMEV0AMER0AcAQ0QUAQ0QXAAwRXQAwRHQBwBDRBQBDRBcADBFdADBEdAHAENEFAENEFwAMEV0AMER0AcAQ0QUAQ0QXAAwRXQAwRHQBwBDRBQBDRBcADBFdADBEdAHAENEFAENEFwAMEV0AMER0AcAQ0fVhbv7rAu1OgLcngLbz4cuLVFlZ6e1pAGiCayEfRnCB9ofoAoAhogsAhoguABgiugBgiOgCgCGiCwCGiC4AGCK6AGCI6AKAIaILAIaILgAYIroAYIjoAoAhogsAhoguABgiugBgiOgCgCGiCwCGiC4AGCK6AGCI6AKAIaILAIaILuAFGzdu1Msvv+z5edeuXUpPT9fPP/98yv0LCwuVk5NjNDu0JaILeElwcLAOHz4sSSooKFBkZKSXZwQLAd6eAHCxGjBggPbs2aOIiAg1NDQoMDBQklRSUqK1a9eqoaFBUVFRSkpKava6/fv3a/PmzWpsbNSIESM0aNAgb0wf54krXcBLIiIidOTIER04cEBxcXGe7d26dVNqaqrmzJmjiooKlZeXe8Ycx9EXX3yhmTNnatasWcrLy1NjY6MXZo/zxZUu4EWRkZHasmWLUlJStGPHDknSsWPHlJOTo/r6eh09elSVlZWe/aurq1VeXq5ly5ZJkmpra1VdXa2QkBCvzB/njugCXnTy1kDnzp092/Ly8pSQkKC+fftq+fLlzfYPDg5WeHi4ZsyYIX9/fzU0NMjf3990zrgwRBfwooiICI0bN67Ztvj4eH366acKDw9vsb/L5dJVV12lZcuWyeVyKTg4WNOmTbOaLloB0QW8IDExscW2G2+8UZIUFxenW265pcV4nz59PONN7wGjY+GNNAAwRHQBwBC3F4B2ZOPGjdqzZ4+CgoIUFRWliRMnqrGxUdnZ2Tp69KiioqJ0zTXXSJKWLl2qtLQ0HT9+XMuX
L9eUKVPUs2dPL58B/hqiCxirq6uT2+0+7fi4ceMUHx/v+Xnfvn0KDQ1VcnKysrOzVVRUpJiYGEknPjK2YsUKTZgwgeB2ENxeAIzt3btXmZmZysvLU21tbYvxDRs2KCMjQ4cOHZIkFRUVed44i4uL0/fffy9Jqq+v1/vvv6/Ro0crNjbWbP64MFzpAsYGDRqk/v37a/fu3Vq5cqWCg4M1fPhwxcTEaMSIEUpMTFRVVZWWLVumtLQ01dTUqFOnTpKkwMBA1dTUSJIqKioUEBCgvn37evN0cI640gW8wO126/LLL9eQIUNUXl6uvXv3Sjqx+EGSQkJCFB4eroqKCgUGBnq+fay2tlZBQUGSpO7du2vgwIHKzs72zkngvBBdwFhZWZnWrFmjrKwsVVVVKSUlRVdffbUkeW431NXVqaysTCEhIYqJidHBgwclSQcOHFDv3r09x0pISJDb7daGDRvsTwTnhdsLgLGqqioNGTJE0dHRLcY+//xzlZSUyHEcjR49Wm63W/Hx8SooKNDbb7+tnj17et5EO2nSpElasWKFtm/friuuuMLqNHCezhhdl8tqGmhNbn5/addOriw7lSlTprTY5ufnp+Tk5Bbb09LSPOMpKSmtN0G0qTNG92hJkbKeXmg1F7SS0NBQJSxa5O1pADiFM14T8T2dHVPTrwIE0L7wiygAGPKJN9KaLp309/fXjBkzWu3YGRkZSklJ8XxOEgAuhE9EV2q5dBIA2iOfiW5TjuNo7dq1Ki0tlcvlUnJyssLCwvTqq68qOjpaP/74oxISElRQUKDy8nJNnjxZvXv3Vk5OjoqLi1VfX9/iy0Pq6+uVnZ2tyspKud1uXXfddZ4HCQLA2fKZ6K5fv165ubmKjo5WbGysgoKClJqaqh9++EFbtmxRUlKSqqqqlJSUpMrKSmVmZmru3LkqKyvTtm3b1Lt3b40dO1Zut1vFxcXaunWrpk6d6jl+fn6++vTpo6FDh2rnzp36+uuvNXLkSC+eMYCOyGei2/T2wpYtW7R371599913kqSwsDBJUteuXdWpUyeFhoaqe/fuCggIUGhoqGct+9atWz1fMuLn1/w9xtLSUh0+fFg7duxQY2Njs1VBAHC2fCa6TYWHh2vAgAEaM2aMJKmhoUHSiedLnU51dbUOHjyo2bNnq7i4WDk5OS2OGRMTo8GDBzc7JgCcC5+Mbnx8vAoLC5WZmSnpxLc6DR069IyvCQoKUlBQkDIyMtSrV68W48OGDdPq1au1fft2SSfWvPfr16/1Jw/Ap7kcxzntYHR0tHNyqSE6lieeeEILF7KaEPCG9PT0bY7jJJxqjMURAGCI6AKAIZ+8p3s6PPQPgLf5XHR56B+A9sznbi/w0D8A7ZnPXeny0D8A7ZnPXelKPPQPQPvlc9HloX8A2jOfu73AQ/8AtGc+F10e+gegPfO52wsA0J4RXQAwRHQBwNAZv2XM5XKVSDpoNx0A8Al9HceJPNXAGaMLAGhd3F4AAENEFwAMEV0AMER0AcAQ0QUAQ/8PCsx/B245jzcAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Draw the sex/label contingency table as a mosaic plot\n",
    "def props(key):\n",
    "    # Tile style callback passed to mosaic(): grey when ' >50K' is one of\n",
    "    # the elements of key, light blue otherwise.\n",
    "    if ' >50K' in key:\n",
    "        return {'color': '0.45'}\n",
    "    return {'color': '#C6E2FF'}\n",
    "\n",
    "\n",
    "mosaic(\n",
    "    cross1[[' >50K', ' <=50K']].stack(),\n",
    "    properties=props,\n",
    "    axes_label=False,\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the data into a training set and a test set (20% held out);\n",
    "# random_state is fixed so the split is repeatable.\n",
    "train_set, test_set = train_test_split(\n",
    "    data,\n",
    "    test_size=0.2,\n",
    "    random_state=2310,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimization terminated successfully.\n",
      "         Current function value: 0.409236\n",
      "         Iterations 8\n",
      "                           Logit Regression Results                           \n",
      "==============================================================================\n",
      "Dep. Variable:             label_code   No. Observations:                26048\n",
      "Model:                          Logit   Df Residuals:                    26042\n",
      "Method:                           MLE   Df Model:                            5\n",
      "Date:                Mon, 06 Nov 2023   Pseudo R-squ.:                  0.2582\n",
      "Time:                        20:42:17   Log-Likelihood:                -10660.\n",
      "converged:                       True   LL-Null:                       -14370.\n",
      "Covariance Type:            nonrobust   LLR p-value:                     0.000\n",
      "===================================================================================\n",
      "                      coef    std err          z      P>|z|      [0.025      0.975]\n",
      "-----------------------------------------------------------------------------------\n",
      "Intercept          -7.1944      0.112    -64.011      0.000      -7.415      -6.974\n",
      "C(sex)[T. Male]     1.2221      0.044     27.602      0.000       1.135       1.309\n",
      "education_num       0.3329      0.008     43.145      0.000       0.318       0.348\n",
      "capital_gain        0.0003   1.08e-05     31.073      0.000       0.000       0.000\n",
      "capital_loss        0.0007    3.6e-05     20.798      0.000       0.001       0.001\n",
      "hours_per_week      0.0308      0.001     20.857      0.000       0.028       0.034\n",
      "===================================================================================\n"
     ]
    }
   ],
   "source": [
    "# Build and fit a logistic regression that includes sex as a categorical term\n",
    "formula = 'label_code ~ C(sex) + education_num + capital_gain + capital_loss + hours_per_week'\n",
    "model = sm.Logit.from_formula(formula, data=train_set)\n",
    "fit_result = model.fit()\n",
    "print(fit_result.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: Maximum number of iterations has been exceeded.\n",
      "         Current function value: 0.422077\n",
      "         Iterations: 35\n",
      "                           Logit Regression Results                           \n",
      "==============================================================================\n",
      "Dep. Variable:             label_code   No. Observations:                26048\n",
      "Model:                          Logit   Df Residuals:                    26035\n",
      "Method:                           MLE   Df Model:                           12\n",
      "Date:                Mon, 06 Nov 2023   Pseudo R-squ.:                  0.2349\n",
      "Time:                        20:42:17   Log-Likelihood:                -10994.\n",
      "converged:                      False   LL-Null:                       -14370.\n",
      "Covariance Type:            nonrobust   LLR p-value:                     0.000\n",
      "=====================================================================================================\n",
      "                                        coef    std err          z      P>|z|      [0.025      0.975]\n",
      "-----------------------------------------------------------------------------------------------------\n",
      "Intercept                            -6.8672      0.138    -49.714      0.000      -7.138      -6.596\n",
      "C(workclass)[T. Federal-gov]          1.1092      0.129      8.616      0.000       0.857       1.362\n",
      "C(workclass)[T. Local-gov]            0.5202      0.116      4.491      0.000       0.293       0.747\n",
      "C(workclass)[T. Never-worked]       -26.5153   1.14e+06  -2.32e-05      1.000   -2.24e+06    2.24e+06\n",
      "C(workclass)[T. Private]              0.5357      0.100      5.376      0.000       0.340       0.731\n",
      "C(workclass)[T. Self-emp-inc]         1.4538      0.127     11.462      0.000       1.205       1.702\n",
      "C(workclass)[T. Self-emp-not-inc]     0.5330      0.115      4.646      0.000       0.308       0.758\n",
      "C(workclass)[T. State-gov]            0.3562      0.128      2.787      0.005       0.106       0.607\n",
      "C(workclass)[T. Without-pay]        -13.9687    945.497     -0.015      0.988   -1867.108    1839.171\n",
      "education_num                         0.3167      0.008     40.970      0.000       0.302       0.332\n",
      "capital_gain                          0.0003   1.08e-05     31.211      0.000       0.000       0.000\n",
      "capital_loss                          0.0008   3.54e-05     21.500      0.000       0.001       0.001\n",
      "hours_per_week                        0.0354      0.001     23.797      0.000       0.032       0.038\n",
      "=====================================================================================================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/tgbaggio/opt/anaconda3/lib/python3.8/site-packages/statsmodels/base/model.py:607: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals\n",
      "  warnings.warn(\"Maximum Likelihood optimization failed to \"\n"
     ]
    }
   ],
   "source": [
    "# Build and fit a logistic regression with workclass as the categorical term.\n",
    "# NOTE: in the saved output this fit stops at the iteration limit without\n",
    "# converging, and the ' Never-worked' and ' Without-pay' levels get very\n",
    "# large standard errors.\n",
    "formula = 'label_code ~ C(workclass) + education_num + capital_gain + capital_loss + hours_per_week'\n",
    "model = sm.Logit.from_formula(formula, data=train_set)\n",
    "fit_result = model.fit()\n",
    "print(fit_result.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimization terminated successfully.\n",
      "         Current function value: 0.405935\n",
      "         Iterations 8\n",
      "                           Logit Regression Results                           \n",
      "==============================================================================\n",
      "Dep. Variable:             label_code   No. Observations:                26048\n",
      "Model:                          Logit   Df Residuals:                    26036\n",
      "Method:                           MLE   Df Model:                           11\n",
      "Date:                Mon, 06 Nov 2023   Pseudo R-squ.:                  0.2642\n",
      "Time:                        20:42:18   Log-Likelihood:                -10574.\n",
      "converged:                       True   LL-Null:                       -14370.\n",
      "Covariance Type:            nonrobust   LLR p-value:                     0.000\n",
      "=========================================================================================================================\n",
      "                                                            coef    std err          z      P>|z|      [0.025      0.975]\n",
      "-------------------------------------------------------------------------------------------------------------------------\n",
      "Intercept                                                -7.5784      0.145    -52.287      0.000      -7.862      -7.294\n",
      "C(workclass, contrast_mat, levels=l) State-gov            0.3520      0.130      2.699      0.007       0.096       0.608\n",
      "C(workclass, contrast_mat, levels=l) Self-emp-not-inc     0.3687      0.116      3.165      0.002       0.140       0.597\n",
      "C(workclass, contrast_mat, levels=l) Private              0.5058      0.101      4.984      0.000       0.307       0.705\n",
      "C(workclass, contrast_mat, levels=l) Federal-gov          1.0792      0.132      8.196      0.000       0.821       1.337\n",
      "C(workclass, contrast_mat, levels=l) Local-gov            0.6012      0.118      5.106      0.000       0.370       0.832\n",
      "C(workclass, contrast_mat, levels=l) Self-emp-inc         1.2921      0.129     10.026      0.000       1.040       1.545\n",
      "C(sex)[T. Male]                                           1.2097      0.045     27.092      0.000       1.122       1.297\n",
      "education_num                                             0.3284      0.008     41.737      0.000       0.313       0.344\n",
      "capital_gain                                              0.0003   1.08e-05     30.770      0.000       0.000       0.000\n",
      "capital_loss                                              0.0007   3.62e-05     20.617      0.000       0.001       0.001\n",
      "hours_per_week                                            0.0287      0.002     18.937      0.000       0.026       0.032\n",
      "=========================================================================================================================\n"
     ]
    }
   ],
   "source": [
    "# Refit with the non-significant workclass levels folded into the baseline.\n",
    "# Level order for workclass: the first 3 entries form the baseline group.\n",
    "# NOTE: the formula below refers to `l` and `contrast_mat` by name.\n",
    "l = [\n",
    "    ' ?', ' Never-worked', ' Without-pay',\n",
    "    ' State-gov', ' Self-emp-not-inc', ' Private',\n",
    "    ' Federal-gov', ' Local-gov', ' Self-emp-inc',\n",
    "]\n",
    "# Dummy coding with one row per level: all-zero rows for the 3 baseline\n",
    "# levels stacked on a 6x6 identity for the remaining levels.\n",
    "contrast = np.vstack([np.zeros((3, 6)), np.eye(6)])\n",
    "# Name each dummy column after the level it indicates.\n",
    "contrast_mat = ContrastMatrix(contrast, l[3:])\n",
    "formula = '''label_code ~ C(workclass, contrast_mat, levels=l)\n",
    "    + C(sex) + education_num + capital_gain\n",
    "    + capital_loss + hours_per_week'''\n",
    "model = sm.Logit.from_formula(formula, data=train_set)\n",
    "fit_result = model.fit()\n",
    "print(fit_result.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimization terminated successfully.\n",
      "         Current function value: 0.409236\n",
      "         Iterations 8\n",
      "                           Logit Regression Results                           \n",
      "==============================================================================\n",
      "Dep. Variable:             label_code   No. Observations:                26048\n",
      "Model:                          Logit   Df Residuals:                    26042\n",
      "Method:                           MLE   Df Model:                            5\n",
      "Date:                Mon, 06 Nov 2023   Pseudo R-squ.:                  0.2582\n",
      "Time:                        20:42:18   Log-Likelihood:                -10660.\n",
      "converged:                       True   LL-Null:                       -14370.\n",
      "Covariance Type:            nonrobust   LLR p-value:                     0.000\n",
      "============================================================================================================\n",
      "                                               coef    std err          z      P>|z|      [0.025      0.975]\n",
      "------------------------------------------------------------------------------------------------------------\n",
      "Intercept                                   -6.3756      0.106    -59.909      0.000      -6.584      -6.167\n",
      "C(sex, contrast_mat, levels=l)Ridit(sex)    -1.2221      0.044    -27.602      0.000      -1.309      -1.135\n",
      "education_num                                0.3329      0.008     43.145      0.000       0.318       0.348\n",
      "capital_gain                                 0.0003   1.08e-05     31.073      0.000       0.000       0.000\n",
      "capital_loss                                 0.0007    3.6e-05     20.798      0.000       0.001       0.001\n",
      "hours_per_week                               0.0308      0.001     20.857      0.000       0.028       0.034\n",
      "============================================================================================================\n"
     ]
    }
   ],
   "source": [
    "# Build and fit a logistic regression that uses a Ridit score for sex.\n",
    "# Level order referenced by the formula: ' Male' first, then ' Female'.\n",
    "l = [' Male', ' Female']\n",
    "# One score per level, in the order of l: ' Male' -> -0.33, ' Female' -> 0.67.\n",
    "# NOTE(review): the scores are hard-coded; presumably derived from the sex\n",
    "# proportions in the data - confirm.\n",
    "contrast = [[-0.33], [0.67]]\n",
    "# Single contrast column, labelled 'Ridit(sex)' in the summary table.\n",
    "contrast_mat = ContrastMatrix(contrast, ['Ridit(sex)'])\n",
    "formula = '''label_code ~ C(sex, contrast_mat, levels=l) + education_num\n",
    "          + capital_gain + capital_loss + hours_per_week'''\n",
    "model = sm.Logit.from_formula(formula, data=train_set)\n",
    "fit_result = model.fit()\n",
    "print(fit_result.summary())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
